package autowire

import (
	"errors"
	"fmt"
	"log"
	"reflect"
	"sync"
)

var (
	// ErrBeanNotFound is returned by BeanFactory.GetBeanByType when neither a
	// registered bean nor a creator matches the requested type.
	ErrBeanNotFound = errors.New("can not find bean")
)

// BeanFactory is a registry of beans (arbitrary values) that can be looked up
// by name or by type and injected into each other's struct fields.
type BeanFactory interface {
	// GetBeanByName returns the bean registered under beanName or, failing
	// that, the result of the creator registered under beanName. found is
	// false when neither exists.
	GetBeanByName(beanName string) (bean any, found bool)
	// GetBeanByType returns the single bean matching beanType. It returns
	// ErrBeanNotFound when nothing matches and a different error when more
	// than one bean matches.
	GetBeanByType(beanType reflect.Type) (bean any, err error)
	// GetBeansByType returns every bean matching beanType; for an interface
	// type this includes beans whose type implements the interface.
	GetBeansByType(beanType reflect.Type) (beans []any)
	// RegisterBeanByName registers bean under a name that must be unique.
	RegisterBeanByName(beanName string, bean any) error
	// RegisterBeans registers unnamed beans, indexed only by their type.
	RegisterBeans(beans ...any)
	// RegisterBeanCreatorByType registers creator as the source of beans of
	// beanType; it is invoked on each type lookup that reaches it.
	RegisterBeanCreatorByType(beanType reflect.Type, creator BeanCreator) error
	// RegisterBeanCreatorByName registers creator as the source of the bean
	// named beanName; it is invoked on each name lookup that reaches it.
	RegisterBeanCreatorByName(beanName string, creator BeanCreator) error
	// Autowire wires the fields of every registered bean.
	Autowire() error
	// AutowireBean wires the fields of the given bean only.
	AutowireBean(bean any) error
}

// NewBeanFactory returns an empty BeanFactory whose internal maps are already
// allocated, so it can be used immediately. The mutex needs no explicit
// initialisation; its zero value is ready to use.
func NewBeanFactory() BeanFactory {
	return &beanFactory{
		beanByName:        make(map[string]any),
		beanByType:        make(map[reflect.Type][]any),
		beanByTypeCreator: make(map[reflect.Type]BeanCreator),
		beanByNameCreator: make(map[string]BeanCreator),
	}
}

// BeanCreator produces a bean on demand and is handed the factory that is
// invoking it (presumably so it can look up its own dependencies). The factory
// does not cache the result: the creator runs again on every lookup that
// reaches it.
type BeanCreator func(factory BeanFactory) any

// beanFactory is the map-backed BeanFactory implementation returned by
// NewBeanFactory.
type beanFactory struct {
	// lock guards every map below.
	lock sync.RWMutex

	// beanByName holds beans registered through RegisterBeanByName.
	beanByName map[string]any

	// beanByType indexes every registered bean (named or not) by its dynamic
	// type.
	beanByType map[reflect.Type][]any

	// beanByTypeCreator holds creators keyed by the type they produce.
	beanByTypeCreator map[reflect.Type]BeanCreator

	// beanByNameCreator holds creators keyed by the bean name they produce.
	beanByNameCreator map[string]BeanCreator
}

// GetBeanByName looks up the bean registered under beanName. If there is none
// but a creator is registered under that name, the creator is invoked after
// the lock is released and its result returned; found is then true even if
// the creator returned nil. When neither exists, it returns (nil, false).
func (bf *beanFactory) GetBeanByName(beanName string) (bean any, found bool) {
	bf.lock.RLock()
	registered, isRegistered := bf.beanByName[beanName]
	creator, hasCreator := bf.beanByNameCreator[beanName]
	bf.lock.RUnlock()

	switch {
	case isRegistered:
		return registered, true
	case hasCreator:
		return creator(bf), true
	default:
		return nil, false
	}
}

// GetBeanByType resolves exactly one bean for beanType.
//
// Beans registered with exactly beanType take precedence. If there is one such
// bean it is returned directly; if there are several, the lookup is an
// ambiguity error. Otherwise the lookup widens to GetBeansByType (beans
// implementing an interface type, plus type creators), which must also yield
// exactly one bean. ErrBeanNotFound is returned when nothing matches.
func (bf *beanFactory) GetBeanByType(beanType reflect.Type) (bean any, err error) {
	bf.lock.RLock()
	exact := bf.beanByType[beanType]
	bf.lock.RUnlock()

	switch {
	case len(exact) == 1:
		return exact[0], nil
	case len(exact) > 1:
		// %v (Type.String) is used because Type.Name is empty for pointer and
		// other unnamed types, which made the message useless.
		return nil, fmt.Errorf("find more than one bean by specific type:%v", beanType)
	}

	// The read lock must already be released here, for two reasons.
	// GetBeansByType takes the lock itself, and sync.RWMutex forbids recursive
	// read locking: it deadlocks if a writer queues up in between.
	// GetBeansByType also runs creators, which receive bf and may call back
	// into it, including the write-locking Register* methods.
	candidates := bf.GetBeansByType(beanType)
	switch len(candidates) {
	case 0:
		return nil, ErrBeanNotFound
	case 1:
		return candidates[0], nil
	default:
		return nil, fmt.Errorf("find more than one bean by type:%v", beanType)
	}
}

// GetBeansByType returns every bean matching beanType.
//
// Registered beans match when their dynamic type equals beanType or, for an
// interface beanType, implements it. Creators are then consulted:
//   - A creator registered for exactly beanType, if any, is the only one
//     invoked.
//   - Otherwise, for an interface beanType, every creator registered for an
//     implementing type is invoked.
//
// Creators run after the lock is released, on every call, and their results
// follow the registered beans in the returned slice. A nil beanType yields
// nil.
func (bf *beanFactory) GetBeansByType(beanType reflect.Type) (beans []any) {
	if beanType == nil {
		// beanType.Kind() would panic on a nil reflect.Type.
		return nil
	}
	isInterface := beanType.Kind() == reflect.Interface

	// implements reports whether t satisfies the interface beanType. A nil key
	// can exist if an untyped nil bean was registered; calling Implements on
	// it would panic.
	implements := func(t reflect.Type) bool {
		return isInterface && t != nil && t.Implements(beanType)
	}

	var creators []BeanCreator
	bf.lock.RLock()
	for t, registered := range bf.beanByType {
		if t == beanType || implements(t) {
			beans = append(beans, registered...)
		}
	}
	if creator, ok := bf.beanByTypeCreator[beanType]; ok {
		creators = append(creators, creator)
	} else if isInterface {
		for t, creator := range bf.beanByTypeCreator {
			if implements(t) {
				creators = append(creators, creator)
			}
		}
	}
	bf.lock.RUnlock()

	// Creators may call back into bf, so they must run without the lock held.
	for _, creator := range creators {
		beans = append(beans, creator(bf))
	}
	return beans
}

// RegisterBeanByName registers bean under beanName. It also indexes bean by
// its dynamic type, so GetBeanByType and GetBeansByType can find it and
// Autowire visits it.
//
// It returns an error if bean is an untyped nil or if beanName is already
// taken. An untyped nil is rejected because it has no reflect.Type, and a nil
// type key would make interface lookups in GetBeansByType panic.
func (bf *beanFactory) RegisterBeanByName(beanName string, bean any) error {
	if bean == nil {
		return fmt.Errorf("bean name[%s] must not be nil", beanName)
	}
	beanType := reflect.TypeOf(bean)

	bf.lock.Lock()
	defer bf.lock.Unlock()

	if _, found := bf.beanByName[beanName]; found {
		return fmt.Errorf("bean name[%s] duplicated", beanName)
	}
	bf.beanByName[beanName] = bean
	bf.beanByType[beanType] = append(bf.beanByType[beanType], bean)
	return nil
}

// RegisterBeans registers unnamed beans, indexing each one by its dynamic type
// so GetBeanByType and GetBeansByType can find it and Autowire visits it.
// Untyped nil values are logged and skipped. They have no reflect.Type, and a
// nil type key would make interface lookups in GetBeansByType panic.
func (bf *beanFactory) RegisterBeans(beans ...any) {
	bf.lock.Lock()
	defer bf.lock.Unlock()
	for i, bean := range beans {
		if bean == nil {
			log.Printf("ignore nil bean at index %d when register beans", i)
			continue
		}
		beanType := reflect.TypeOf(bean)
		bf.beanByType[beanType] = append(bf.beanByType[beanType], bean)
	}
}

// RegisterBeanCreatorByType registers creator as the source of beans of
// beanType. GetBeansByType (and therefore GetBeanByType) invokes it on each
// lookup that reaches it.
//
// It returns an error in three cases:
//   - beanType is nil, because a nil key would make interface lookups panic.
//   - creator is nil, because it would panic when invoked.
//   - A creator is already registered for beanType.
func (bf *beanFactory) RegisterBeanCreatorByType(beanType reflect.Type, creator BeanCreator) error {
	if beanType == nil {
		return errors.New("bean creator beanType must not be nil")
	}
	if creator == nil {
		return fmt.Errorf("bean creator[beanType:%v] must not be nil", beanType)
	}

	bf.lock.Lock()
	defer bf.lock.Unlock()

	if _, found := bf.beanByTypeCreator[beanType]; found {
		// %v (Type.String) instead of Type.Name, which is empty for pointer
		// and other unnamed types.
		return fmt.Errorf("bean creator[beanType:%v] duplicated", beanType)
	}
	bf.beanByTypeCreator[beanType] = creator
	return nil
}

// RegisterBeanCreatorByName registers creator to produce the bean named
// beanName. GetBeanByName invokes it on every lookup of beanName for which no
// bean was registered under that name.
//
// It returns an error if creator is nil (it would panic when invoked) or if a
// creator is already registered for beanName.
func (bf *beanFactory) RegisterBeanCreatorByName(beanName string, creator BeanCreator) error {
	if creator == nil {
		return fmt.Errorf("bean creator[beanName:%s] must not be nil", beanName)
	}

	bf.lock.Lock()
	defer bf.lock.Unlock()

	if _, found := bf.beanByNameCreator[beanName]; found {
		return fmt.Errorf("bean creator[beanName:%s] duplicated", beanName)
	}
	bf.beanByNameCreator[beanName] = creator
	return nil
}

// Autowire calls AutowireBean on every bean currently indexed by type, that
// is, everything registered through RegisterBeanByName or RegisterBeans. Beans
// are visited in no particular order, and wiring stops at the first error.
//
// The bean list is copied under the read lock, and the lock is released
// before wiring. Ranging over the map without the lock raced with concurrent
// registration. Holding the lock while wiring is not an option either:
// AutowireBean re-enters the factory and may run creators, which receive the
// factory.
func (bf *beanFactory) Autowire() error {
	bf.lock.RLock()
	var all []any
	for _, beans := range bf.beanByType {
		all = append(all, beans...)
	}
	bf.lock.RUnlock()

	for _, bean := range all {
		if err := bf.AutowireBean(bean); err != nil {
			return err
		}
	}
	return nil
}

// AutowireBean injects dependencies into bean, which must be a non-nil pointer
// to a struct. It considers every settable (exported) field that still holds
// its zero value:
//
//   - If the field has a `bean:"name"` tag and GetBeanByName yields a non-nil
//     bean for that name, that bean is assigned.
//   - Otherwise, for interface, pointer and struct fields, the single bean
//     returned by GetBeanByType for the field's type is assigned.
//
// A field for which no bean is found is logged and left untouched. An error
// is returned in three cases:
//   - bean has the wrong shape.
//   - A type lookup fails for another reason, such as ambiguity.
//   - A resolved bean is not assignable to the field.
//
// The original code panicked in the first and last cases.
func (bf *beanFactory) AutowireBean(bean any) error {
	rv := reflect.ValueOf(bean)
	if rv.Kind() != reflect.Pointer {
		return fmt.Errorf("bean[%T] must be a pointer", bean)
	}
	if rv.IsNil() {
		return fmt.Errorf("bean[%T] must not be a nil pointer", bean)
	}
	rve := rv.Elem()
	if rve.Kind() != reflect.Struct {
		// NumField would panic on anything but a struct.
		return fmt.Errorf("bean[%T] must point to a struct", bean)
	}

	beanType := rv.Type()
	structType := rve.Type()
	for i := 0; i < rve.NumField(); i++ {
		fv := rve.Field(i)
		// Skip unexported fields and fields that already hold a value.
		if !fv.CanSet() || !fv.IsZero() {
			continue
		}
		field := structType.Field(i)

		// Wire by name when tagged; an unknown name falls back to wiring by type.
		if beanName := field.Tag.Get("bean"); len(beanName) > 0 {
			if b, found := bf.GetBeanByName(beanName); found && b != nil {
				if err := assignBean(fv, b, field, beanType); err != nil {
					return err
				}
				continue
			}
		}

		// Wire by type; only interface, pointer and struct fields are candidates.
		ft := fv.Type()
		if k := ft.Kind(); k != reflect.Interface && k != reflect.Pointer && k != reflect.Struct {
			continue
		}
		b, err := bf.GetBeanByType(ft)
		if errors.Is(err, ErrBeanNotFound) || (err == nil && b == nil) {
			// A creator returning nil is treated like a missing bean:
			// reflect.Value.Set would panic on it.
			log.Printf(`can not find bean by type:%v, when autowire bean:%v, ignore the field:"%s"`, ft, beanType, field.Name)
			continue
		}
		if err != nil {
			return fmt.Errorf("autowire field %q of bean %v: %w", field.Name, beanType, err)
		}
		if err := assignBean(fv, b, field, beanType); err != nil {
			return err
		}
	}

	return nil
}

// assignBean stores the non-nil bean b into the settable field value fv. It
// returns an error instead of letting reflect.Value.Set panic when b's dynamic
// type is not assignable to the field's type.
func assignBean(fv reflect.Value, b any, field reflect.StructField, beanType reflect.Type) error {
	v := reflect.ValueOf(b)
	if !v.Type().AssignableTo(fv.Type()) {
		return fmt.Errorf("can not assign bean of type %v to field %q (%v) of bean %v", v.Type(), field.Name, fv.Type(), beanType)
	}
	fv.Set(v)
	return nil
}

// GetReflectType returns the reflect.Type of T itself. It goes through a typed
// nil *T and takes Elem, so interface types are preserved. Calling
// reflect.TypeOf on an interface value would report its dynamic type, or nil,
// instead of the interface.
func GetReflectType[T any]() reflect.Type {
	var ptr *T
	return reflect.TypeOf(ptr).Elem()
}

// GetBeanByType is a typed wrapper around BeanFactory.GetBeanByType that
// resolves the single bean of type T. Errors from the factory are returned
// unchanged. If the resolved value is not a T (for example, a creator returned
// nil or a value of another type), an error is returned instead of panicking
// on the type assertion.
func GetBeanByType[T any](bf BeanFactory) (bean T, err error) {
	beanType := GetReflectType[T]()
	b, err := bf.GetBeanByType(beanType)
	if err != nil {
		return bean, err
	}
	typed, ok := b.(T)
	if !ok {
		return bean, fmt.Errorf("bean of type %T is not a %v", b, beanType)
	}
	return typed, nil
}

// GetBeansByType is a typed wrapper around BeanFactory.GetBeansByType that
// returns every bean matching T. The result is never nil. Values that are not
// a T, such as nil results from a creator, are logged and skipped. The
// original unchecked type assertion panicked on them.
func GetBeansByType[T any](bf BeanFactory) (beans []T) {
	beanType := GetReflectType[T]()
	found := bf.GetBeansByType(beanType)
	beans = make([]T, 0, len(found))
	for _, b := range found {
		typed, ok := b.(T)
		if !ok {
			log.Printf("ignore bean of type %T, which is not a %v", b, beanType)
			continue
		}
		beans = append(beans, typed)
	}
	return beans
}

// RegisterBeanCreatorByType registers creator on bf as the source of beans of
// type T. It is shorthand for calling bf.RegisterBeanCreatorByType with
// GetReflectType[T](), and it returns that call's error.
func RegisterBeanCreatorByType[T any](bf BeanFactory, creator BeanCreator) error {
	target := GetReflectType[T]()
	return bf.RegisterBeanCreatorByType(target, creator)
}
